import os
import sys
from collections import defaultdict, Counter
import pysam
from numpy import *
from pylab import *


# The single command-line argument names the sequencing run to process.
if len(sys.argv) < 2:
    sys.exit("Usage: %s dataset" % sys.argv[0])
dataset = sys.argv[1]
# Reject unsupported datasets before touching the file system, so the user sees
# this message rather than a FileNotFoundError from os.listdir below.
if dataset != "MiSeq":
    raise Exception("Unexpected dataset '%s'" % dataset)


directory = "/osc-fs_home/mdehoon/Data/CASPARs"
subdirectory = os.path.join(directory, dataset, "Mapping")
# Collect the library names (BAM file names without extension) to process.
libraries = []
for filename in os.listdir(subdirectory):
    library, extension = os.path.splitext(filename)
    # The Mapping directory is expected to contain BAM files only.
    assert extension == ".bam"
    if dataset == "MiSeq" and library == 'neg_r1':
        print("Skipping negative control library neg_r1 prepared without the 3' linker")
    else:
        libraries.append(library)
libraries.sort()
# Sanity check: the MiSeq run is expected to yield exactly 35 libraries.
assert len(libraries) == 35

# Lengths at or above this value are pooled into the final histogram bin.
maxlength = 1000
# Maps (annotation, transcript) to a length histogram: element i holds the
# (fractional) number of read pairs of length i, and element `maxlength`
# collects every pair whose length is >= maxlength.
counts = defaultdict(lambda: zeros(maxlength+1))
for library in libraries:
    filename = "%s.bam" % library
    path = os.path.join(subdirectory, filename)
    sys.stderr.write("Reading %s\n" % path)
    # Use the file as a context manager so each BAM handle is closed once the
    # library has been processed, instead of leaking one handle per library.
    with pysam.AlignmentFile(path, "rb") as alignments:
        for alignment1 in alignments:
            # Consecutive records are consumed as a pair (alignment1, alignment2);
            # presumably the BAM files are name-collated so these are mates -- TODO confirm.
            alignment2 = next(alignments, None)
            if alignment2 is None:
                # Without this check an odd record count would surface as a
                # bare StopIteration with no indication of the offending file.
                raise Exception("%s: alignment %s has no mate record"
                                % (path, alignment1.query_name))
            try:
                target = alignment1.get_tag("XT")
            except KeyError:
                # Records without an XT tag must be unmapped pairs; skip them.
                assert alignment1.is_unmapped
                assert alignment2.is_unmapped
                continue
            if target in ("mRNA", "lncRNA", "gencode", "fantomcat",
                          "novel", "genome", 'TERC', 'MALAT1', 'snhg'):
                # For these targets the annotation comes from the XA tag;
                # pairs without one are labelled "other_intergenic".
                try:
                    annotation = alignment1.get_tag("XA")
                except KeyError:
                    annotation = "other_intergenic"
            else:
                annotation = target  # known small RNAs
            if target in ("snRNA", "scRNA", "snoRNA", "scaRNA"):
                # These targets are required to carry an XR transcript list.
                transcripts = alignment1.get_tag("XR")
            elif target in ("mRNA", "lncRNA", "gencode", "fantomcat"):
                transcripts = None
            else:
                # For every other target an XR tag is not expected; its
                # presence is treated as an error.
                try:
                    transcripts = alignment1.get_tag("XR")
                except KeyError:
                    transcripts = None
                else:
                    raise Exception(target)
            # XR holds a comma-separated transcript list; "-" marks "none".
            if transcripts is None:
                transcripts = ['-']
            else:
                transcripts = transcripts.split(",")
            # Mapped pairs are weighted by 1/NH (number of reported
            # alignments), presumably so a multi-mapper sums to one overall.
            if alignment1.is_unmapped:
                count = 1.0
            else:
                multimap = alignment1.get_tag("NH")
                count = 1 / multimap
            if target == 'genome':
                # Fragment span: from the forward mate's start to the
                # reverse mate's end.
                if alignment1.is_reverse:
                    assert not alignment2.is_reverse
                    start = alignment2.reference_start
                    end = alignment1.reference_end
                else:
                    assert alignment2.is_reverse
                    start = alignment1.reference_start
                    end = alignment2.reference_end
                assert start < end
                length = end - start
            else:
                length = alignment1.get_tag("XL")
            # Clamp into the overflow bin.
            length = min(length, maxlength)
            # Split the pair's weight evenly over all of its transcripts.
            count /= len(transcripts)
            for transcript in transcripts:
                key = (annotation, transcript)
                counts[key][length] += count


# Totals over each length histogram, aggregated per annotation and per
# (annotation, transcript) key.
annotation_counts = Counter()
transcript_counts = Counter()
for key, histogram in counts.items():
    total = sum(histogram)
    annotation_counts[key[0]] += total
    transcript_counts[key] += total

# Sort key for each (annotation, transcript): the annotation's total first,
# then the key's own total.
key_counts = {
    key: (annotation_counts[key[0]], transcript_counts[key])
    for key in counts
}

# Keys in descending order of (annotation total, key total).
keys = sorted(counts, key=lambda k: key_counts[k], reverse=True)

filename = "rnasize.%s.txt" % dataset
print("Saving RNA sizes to %s" % filename)
# A context manager guarantees the file is flushed and closed even if an
# exception occurs partway through writing.
with open(filename, 'wt') as stream:
    # Header: rank, annotation, transcript, one column per length
    # 0..maxlength-1, then a final column for the pooled overflow bin.
    terms = ["#rank", "annotation", "transcript"]
    terms.extend(str(length) for length in range(maxlength))
    terms.append(">%d" % (maxlength-1))
    stream.write("\t".join(terms) + "\n")
    # One tab-separated row per key, in ranked order.
    for i, key in enumerate(keys):
        annotation, transcript = key
        if transcript is None:
            transcript = "-"
        terms = ["%d" % i, "%30s" % annotation, "%20s" % transcript]
        terms.extend(str(count) for count in counts[key])
        stream.write("\t".join(terms) + "\n")

